import sys
import numpy as np
import pickle
sys.path.append(sys.argv[-1])
from signal_udfs import peak_onset
from signal_udfs import plot_average_signals
import pandas as pd

# Command-line arguments, read in order:
#   argv[1] -> CSV input file; column 1 = time, remaining columns = event traces
#   argv[2] -> the exact string 'True' enables signal inversion; anything else disables it
#   argv[3] -> output name prefix (passed to peak_onset and used for plot names)
# NOTE: the last argument is also appended to sys.path at the top of the file
# so that signal_udfs can be imported.
input_data = sys.argv[1]
invert = (sys.argv[2] == 'True')
save_handle = sys.argv[3]

# Import timestamps: collect the first-column value of every line whose time
# cell is neither blank nor 'NaN'. The file is opened with `with` so the
# handle is closed as soon as reading finishes.
# `x` is left bound to the split fields of the file's last line after the loop.
time = []
with open(input_data) as csv_file:
	for line in csv_file:
		x = line.strip().split(',')
		if x[0] != '' and x[0] != 'NaN':
			time.append(float(x[0]))

import itertools

# Import event traces from the first len(time) lines of the file (this assumes
# the timestamped rows are the leading rows). Every non-blank, non-'NaN' value
# in columns 2..N is appended to the list for its column and to valid_points.
# The number of event columns is taken from the widest row actually read, so a
# trailing blank line or a short final line cannot shrink event_list and cause
# an IndexError below.
valid_points = []
with open(input_data) as csv_file:
	rows = [line.strip().split(',') for line in itertools.islice(csv_file, len(time))]

event_index = list(range(max((len(row) for row in rows), default=1) - 1))
event_list = [[] for _ in event_index]

for row in rows:
	for point_num, point in enumerate(row[1:]):
		if point != '' and point != 'NaN':
			value = float(point)
			event_list[point_num].append(value)
			valid_points.append(value)

# Moving-average smoothing: convolve each trace with a box kernel of
# kernel_size equal weights (1/kernel_size each) in 'same' mode.
# np.convolve zero-pads beyond the ends, so samples within about
# kernel_size/2 of either end average in implicit zeros.
# NOTE(review): for a trace shorter than kernel_size, 'same' mode returns
# kernel_size samples (max of the two lengths), not len(trace).
kernel_size = 100
kernel = np.full(kernel_size, 1 / kernel_size)
smooth_events = [np.convolve(trace, kernel, mode='same') for trace in event_list]

# Estimate peak onsets on every smoothed trace. peak_onset (from signal_udfs)
# is given the trace, its index, the shared time axis, the invert flag and the
# output prefix. Traces for which it returns None are skipped.
estimated_onsets = []

for event_num, event in enumerate(smooth_events):
	onset = peak_onset(event, event_num, time, invert, save_handle)
	# Identity test against None: `!=` would dispatch to the result's own
	# __ne__, which is element-wise (and ambiguous in an `if`) for array-likes.
	if onset is not None:
		estimated_onsets.append(onset)

# Regroup the peak_onset results by field: each result is read positionally,
# with element i appended to the list stored under names[i].
names = ['Event Name', 'Peak Onset Indices', 'Peak Onset Times', 'Peak Onset Amps', 'Max Peak Indices', 'Max Peak Times', 'Max Peak Amps', 'Min Valley Indices', 'Min Valley Times', 'Min Valley Amps']
onset_dict = {field: [] for field in names}

for item in estimated_onsets:
	for field_num, field in enumerate(names):
		onset_dict[field].append(item[field_num])
	
# Plot the averaged event signals (with error bands) twice: once from the raw
# traces and once from the smoothed traces, each under its own name suffix.
for traces, suffix in ((event_list, '_avg_events'), (smooth_events, '_avg_smooth')):
	plot_average_signals(time, traces, 'Time (s)', 'DF/F0 Zscore', save_handle + suffix)

# Pickle the intermediate results to fixed filenames in the current working
# directory for later use. Each file is written inside `with`, so it is closed
# and flushed immediately rather than whenever the handle is garbage-collected.
for filename, obj in (
	('onset_dict.pkl', onset_dict),
	('time.pkl', time),
	('smooth_events.pkl', smooth_events),
	('event_list.pkl', event_list),
):
	with open(filename, 'wb') as pickle_file:
		pickle.dump(obj, pickle_file)
